%% Clean up
clear variables

%% General
% Paths
% Directory names end with the platform's file separator (filesep) rather
% than a hard-coded '\', so the [dir, file] concatenations used throughout
% this script also resolve on Linux/macOS. On Windows the strings are
% identical to the previous literals.
% NOTE(review): `path` shadows MATLAB's built-in path function; the name is
% kept because later sections of this script still reference it.
filename    = 'Chains.xlsx';
datapath    = ['Output', filesep];
path        = ['parameters_OJS', filesep];
name        = 'replication';

% Load the parameter sets. The code below uses param_ini and param_end
% right after these calls, so the files are expected to provide them.
load([path,'param_ini_',name,'.mat']);
load([path,'param_end_',name,'.mat']);
% Working parameter struct starts as a copy of the initial-state parameters
param=param_ini;

% Load the `aux` struct saved by a previous run, keep it as aux_old, and
% clear `aux` so it can be rebuilt from scratch below.
load([path,'aux_transition.mat']);
aux_old=aux;
clear aux

%% Set auxiliary stuff
% Build the settings struct in a single call. Field order matches the
% original one-by-one assignments: CN, T, n_m, N_T, pen.
aux = struct( ...
    'CN',  0.5, ...                      CN share
    'T',   12*7, ...                     time periods in months
    'n_m', 1600, ...                     number of grid points
    'N_T', length(aux_old.theta), ...    nr grid points for IRF (taken from the stored theta path)
    'pen', 10^(5));                    % penalty in HJB iteration

%% Create grid for timesteps
% The time grid is built from two pieces that join at aux.T/2:
%   1) an exponentially stretched piece on (0, aux.T/2] (finer spacing
%      near 0), from which the first n nodes (the t = 0 node) are dropped;
%   2) an evenly spaced piece on (aux.T/2, aux.T] with tmp_n1 nodes.
% Together they contain (aux.N_T - tmp_n1) + tmp_n1 = aux.N_T nodes.
n=1;
scale=3.5;
% Number of nodes in the evenly spaced second piece (22% of aux.N_T).
% linspace silently floors a non-integer point count, so without rounding
% the two pieces could add up to fewer than aux.N_T nodes whenever
% 0.22*aux.N_T is not an exact integer. Rounding keeps the total at aux.N_T.
tmp_n1=round(0.22.*aux.N_T);
tmp_t_grid1=(exp(scale.*linspace(0,1,n+(aux.N_T-tmp_n1)))-1)./(exp(scale)-1).*aux.T/2;
tmp_t_grid2=linspace(aux.T/2,aux.T,tmp_n1+1);
% Drop the first n nodes of piece 1 and the duplicated joint node of piece 2
aux.t_vec=[tmp_t_grid1(1+n:end),tmp_t_grid2(2:end)];
clear tmp_*
% The lambda/theta path is defined on the same time grid
aux.t_lam=aux.t_vec;

% With param.mr = 0 the exponential equals 1, so aux.p_t is
% 1 - param.shock at every grid point.
param.mr=0;
aux.p_t = 1-param.shock.*exp(param.mr.*aux.t_vec);

%% Initial guess for path for lambda
% Disabled alternative: a flat guess at the end-state value
% theta=param_end.lambda.*ones(size(aux.t_lam));

% Re-sample the theta path stored by the previous run onto the new time
% grid. Linear interpolation without extrapolation leaves NaN outside the
% old grid; those entries are then filled with the nearest stored endpoint.
theta = interp1(aux_old.t_vec, aux_old.theta, aux.t_vec, 'linear');

tmp_after  = aux.t_vec > aux_old.t_vec(end);   % new nodes beyond the old grid
theta(tmp_after) = aux_old.theta(end);

tmp_before = aux.t_vec < aux_old.t_vec(1);     % new nodes ahead of the old grid
theta(tmp_before) = aux_old.theta(1);

clear tmp_after tmp_before

%% Create grid for m
% Piecewise-uniform grid in log(m) made of five segments (a..e) that share
% their end nodes; the duplicated first node of b..e is dropped when the
% segments are stacked. Segment sizes are chosen so that the stacked grid
% has exactly aux.n_m nodes:
%   tmp1 + tmp2 + tmp3 + (aux.n_m - tmp2 - tmp3 - 2*tmp1) + tmp1 = aux.n_m
tmp=log(param_end.m_u)-log(param_end.m_l);
% All segment sizes must be integers: linspace floors a non-integer count,
% which would make the grid shorter than aux.n_m (and mismatch the
% aux.n_m-sized operators built from it). tmp1 was previously left
% unrounded, unlike tmp2 and tmp3.
tmp1=round(aux.n_m/20);
tmp2=round(aux.n_m/2.4);
tmp3=round(aux.n_m/7);
a=linspace(log(param_end.m_l)-0.15,log(param_end.m_l)-0.03,tmp1)';
b=linspace(log(param_end.m_l)-0.03,log(param_end.m_l)+tmp/4,tmp2+1)';
c=linspace(log(param_end.m_l)+tmp/4,log(param_end.m_h)-tmp/4,tmp3+1)';
d=linspace(log(param_end.m_h)-tmp/4,log(param_end.m_h)+0.01,aux.n_m-tmp2-tmp3-2.*tmp1+1)';
e=linspace(log(param_end.m_h)+0.01,log(param_end.m_u)+0.15,tmp1+1)';
ln_m_HJB=[a;b(2:end);c(2:end);d(2:end);e(2:end)];
% Fail early if the segment bookkeeping does not produce aux.n_m nodes
assert(numel(ln_m_HJB)==aux.n_m, ...
    'Grid for m has %d nodes but aux.n_m is %d.', numel(ln_m_HJB), aux.n_m);
param.ln_m_HJB=ln_m_HJB;%linspace(log(param.m_l)-0.3,log(param.m_u)+0.3,aux.n_m)';

%% Create transition matrix
% Assemble the linear operators stored in param.T and param.T_lam from a
% sparse identity and the matrices returned by M_p and M_pp on the log(m)
% grid.
% NOTE(review): M_p / M_pp presumably return first- and second-difference
% operators on the grid -- confirm in their definitions.
T=speye(aux.n_m);
% Only T_p is used below; T_l_p and T_u_p are not referenced again in the
% visible part of this script.
[T_p, T_l_p, T_u_p] = M_p(param.ln_m_HJB,ones(aux.n_m,1));
[T_pp] = M_pp(param.ln_m_HJB);
% param.T = r*I - (mu - sigma^2/2)*T_p - (sigma^2/2)*T_pp
param.T=param.r.*T-(param.mu-param.sigma.^2./2).*T_p-param.sigma.^2./2.*T_pp;
% param.T_lam = I - (1 - alpha)*T_p
param.T_lam=T-(1-param.alpha).*T_p;

%% Use analytical functions to calculate value functions and distributions on grid
% Evaluate J_fun on the HJB grid for the end-state and initial-state
% parameter sets. The *_analy copies are kept for comparison, because
% J_ini / J_end are overwritten by the finite-difference solution later.
J_end=J_fun(exp(param.ln_m_HJB),param_end);
J_end_analy = J_end;
J_ini=J_fun(exp(param.ln_m_HJB),param_ini);
J_ini_analy = J_ini;
% Evaluate q_fun on the grid returned by ln_m_KFE, with the grid values
% clamped to [m_l, m_u] before the call.
Q_ini=q_fun(min(param.m_u,max(param.m_l, exp(ln_m_KFE([param.m_l,param.m_u],1,aux)))), param);
% Rescale so that Q(1)*s + Q(end) - Q(1) equals 1
Q_ini=Q_ini./(Q_ini(1).*param.s+Q_ini(end)-Q_ini(1));
Q_ini_analy = Q_ini;
% Same construction for the end state
Q_end=q_fun(min(param_end.m_u,max(param_end.m_l,exp(ln_m_KFE([param_end.m_l,param_end.m_u],1,aux)))),param_end);
% NOTE(review): this rescaling uses param.s (initial parameter set), not
% param_end.s -- confirm that s is identical in both sets.
Q_end=Q_end./(Q_end(1).*param.s+Q_end(end)-Q_end(1));
Q_end_analy = Q_end;

%% Solve value functions and distributions on grid using finite difference
% Settings added to aux before the solver calls below
aux.dt = 0.05;
aux.n  = 10^4;

% Fixed-point iteration on HJB_t_OJS, started from the analytical
% initial-state value function, until the sup-norm change between two
% consecutive iterates is at most 1e-13.
J_old  = J_ini;
disc_J = 1;
while disc_J > 10^(-13)
    J_ini  = HJB_t_OJS(J_old, param, aux);   % J_old holds the current iterate
    disc_J = max(abs(J_ini - J_old))         % no semicolon: residual is echoed every pass
    J_old  = J_ini;
end

% calculate implied boundaries: invert the converged value function
% (J_ini equals J_old here) to find the m at which it equals 0 and param.c
param.m_l = interp1(J_ini, exp(param.ln_m_HJB), 0,       'pchip');
param.m_h = interp1(J_ini, exp(param.ln_m_HJB), param.c, 'pchip');

% Upper boundary: root of delta_OJS above m_h; abs() keeps the offset
% non-negative whatever sign fzero returns.
x0 = fzero(@(x) delta_OJS(abs(x)+param.m_h, param), ...
           10^(-2), optimset('display','off'));
param.m_u = param.m_h + abs(x0);

% Distribution on the grid spanned by the numerically implied boundaries
Q_ini = KFE_t([param.m_l,param.m_u], Q_ini, ...
              delta_fun(exp(ln_m_KFE([param.m_l,param.m_u],1,aux)), param_ini), ...
              param, aux);

% Give the end-state parameter set the same HJB grid and operators
for tmp_f = {'ln_m_HJB','T','T_lam'}
    param_end.(tmp_f{1}) = param.(tmp_f{1});
end
clear tmp_f

% Same fixed-point iteration as for the initial state, now with param_end
J_old  = J_end;
disc_J = 1;
while disc_J > 10^(-13)
    J_end  = HJB_t_OJS(J_old, param_end, aux);   % J_old holds the current iterate
    disc_J = max(abs(J_end - J_old))             % no semicolon: residual is echoed every pass
    J_old  = J_end;
end

% Implied lower/middle boundaries of the end state (J_end equals J_old here;
% param_end.ln_m_HJB is the grid copied from param above).
% NOTE(review): unlike the initial state, m_u is not recomputed for the
% end state -- confirm this is intended.
param_end.m_l = interp1(J_end, exp(param_end.ln_m_HJB), 0,           'pchip');
param_end.m_h = interp1(J_end, exp(param_end.ln_m_HJB), param_end.c, 'pchip');

Q_end = KFE_t([param_end.m_l,param_end.m_u], Q_end, ...
              delta_fun(exp(ln_m_KFE([param_end.m_l,param_end.m_u],1,aux)), param_end), ...
              param_end, aux);

% Keep both sets of boundaries and the baseline drift on the working struct
param.m_l_ini = param.m_l;
param.m_u_ini = param.m_u;
param.m_l_end = param_end.m_l;
param.m_u_end = param_end.m_u;
param.mu_base = param_ini.mu;

%% Check difference between analytical solutions and solutions on grid
% Grids in levels, used for the plots and for the exported table
m_grid_HJB = exp(param.ln_m_HJB);
m_grid_FPE = exp(ln_m_KFE([param.m_l,param.m_u],1,aux));

% Figure 1: value functions, numerical (solid) vs analytical (dashed)
figure(1)
plot(m_grid_HJB, J_ini,       'b')
hold on
plot(m_grid_HJB, J_ini_analy, 'k--')
plot(m_grid_HJB, J_end,       'r')
plot(m_grid_HJB, J_end_analy, 'g--')
hold off

% Figure 2: distributions against node index, numerical vs analytical
figure(2)
plot(Q_ini,       'b')
hold on
plot(Q_ini_analy, 'k--')
plot(Q_end,       'r')
plot(Q_end_analy, 'g--')
hold off

% Export the initial-state comparison. T (the sparse identity used for the
% operators earlier) is overwritten by this table.
T = table(m_grid_HJB, J_ini_analy, J_ini, m_grid_FPE, Q_ini_analy, Q_ini);
writetable(T, [datapath,filename], 'Sheet', 'nume_acc_OJS', 'Range', 'A1')


%% Assign mean productivity
% End-state mean from mean_t_fun on the exp(ln_m/(1-alpha)) grid, scaled
% by (1 - s*Q_end(1))
param.mean = (1-param.s*Q_end(1)) * mean_t_fun( ...
    exp(1./(1-param.alpha).*ln_m_KFE([param_end.m_l,param_end.m_u],1,aux)), ...
    Q_end);
param.omega_0 = param_end.omega_0;

%% Calculate the sluggish adjustment
adj   = 1;                              % overall step scale of the theta update
adj1  = exp(-0.06.*aux.t_lam);          % time-decaying weight of the theta update
diff0 = 1;                              % start above tol so the loop is entered
tol   = 5*10^(-6);                      % convergence tolerance on max |excess|
% Time-step lengths of aux.t_vec (first step measured from 0); not
% referenced again in the visible part of this script.
dt    = [aux.t_vec(1),aux.t_vec(2:end)-aux.t_vec(1:end-1)];


%% Solve for the transition dynamics; update lambda path based on excess demand
% Repeatedly evaluate the transition for the current theta path and move
% theta in the direction of the excess until max |excess| <= tol.
% The excess vector is kept in `excess` rather than `diff`, and the unused
% copy of diff_t.mean is no longer stored in a variable called `mean`, so
% the built-in functions diff and mean are not shadowed in the workspace.
% NOTE(review): there is no iteration cap; if the update does not converge
% this loop runs until interrupted.
i=1;
while diff0>tol
    diff_t = transition_t_OJS(theta, J_end, Q_ini, param, aux, 1);
    theta_old = theta;

    excess = diff_t.excess;
    diff0 = max(abs(excess));
    if diff0>tol
        % Damped update: the excess is clipped to [-0.5, 0.5], scaled by adj
        % and the time-decaying weights adj1; theta is floored at 1e-4.
        theta = theta + adj.*max(-0.5,min(0.5,excess)).*adj1;
        theta = max(theta,0.0001);
    end
    % Progress display (left unsuppressed), scaled by 100: first five
    % excess values, last value, max |excess|, last value again, and
    % 100 x mean |excess|.
    100*[excess(1:5), excess(end), diff0, excess(end), 10^2.*sum(abs(excess))/length(aux.t_lam)]
    i = i + 1;
end

%% Save various objects along the transition

% Evenly spaced grid in m covering both steady states, padded by 0.01
param.m_grid = linspace(min(param_end.m_l,param_ini.m_l)-0.01, ...
                        max(param_end.m_u,param_ini.m_u)+0.01, 200);

% Starting point of the transition: numerical initial distribution and the
% initial parameter set as loaded
Q_start     = Q_ini;
param_start = param_ini;

% `diff` is reused from here on as a struct of series on the figure grid
clear diff
aux.T_fig = 12*5;

% 1000 evenly spaced nodes on (0, aux.T_fig]; the t = 0 node is dropped
tmp_s = linspace(0,1,1001);
diff.t_grid = tmp_s(2:end).*aux.T_fig;

% Linearly interpolate (with extrapolation) each transition series from the
% solver grid aux.t_vec onto the figure grid
for tmp_f = {'u','lambda','ete','mean_m','hires_tot','new_chains', ...
             'm_l_t','m_h_t','m_u_t'}
    diff.(tmp_f{1}) = interp1(aux.t_vec, diff_t.(tmp_f{1}), diff.t_grid, 'linear', 'extrap');
end
clear tmp_s tmp_f

%% Flows
% Build the series exported later. Each series is prefixed with its
% initial-state value: once here (row vectors u, lambda) or twice (column
% vectors), matching the two pre-shock entries of t_grid defined below.
% NOTE(review): the grids below use param_start (= param_ini as loaded)
% boundaries, whereas Q_start was computed on the numerically implied
% boundaries param.m_l / param.m_u -- confirm the two grids are meant to
% be treated as equivalent.
u_ini = Q_start(1).*param.s;
u = [Q_start(1).*param.s, diff.u];

lambda=[param_start.lambda, diff.lambda];
% Initial values of the ete series and of total hires
ete_ini = jtj_num( delta_fun( exp(ln_m_KFE([param_start.m_l,param_start.m_u],1,aux)), param_start), Q_start);
hires_ini = ete_ini.*(1-u(1))+ param_start.lambda.*u(1);
ete = [ ete_ini; ete_ini; diff.ete'];
hires = [hires_ini; hires_ini;  diff.hires_tot'];


% Boundary paths, prefixed with the initial-state boundaries
m_l_t = [ param_start.m_l; param_start.m_l;  diff.m_l_t'];
m_h_t = [ param_start.m_h; param_start.m_h;  diff.m_h_t'];
m_u_t = [ param_start.m_u; param_start.m_u;  diff.m_u_t'];


% G_r: Q_start rescaled to run from 0 to 1, evaluated at m_h
G_r = interp1( exp(ln_m_KFE([param_start.m_l,param_start.m_u],1,aux)), (Q_start-Q_start(1))./(Q_start(end)-Q_start(1)), param_start.m_h, 'linear');
new_chains_ini = param.s.*param_start.lambda.*G_r.*(1-u(1)) + u(1).*param_start.lambda;
new_chains = [new_chains_ini; new_chains_ini; diff.new_chains'];

% eu = (du/dt + lambda*u)/(1-u), with du/dt a backward difference on the
% figure grid; u is still the row vector [u_ini, diff.u] at this point.
eu = (u(2:end) - u(1:end-1))./[diff.t_grid(1), diff.t_grid(2:end)-diff.t_grid(1:end-1)]./(1-u(2:end)) +  diff.lambda.*u(2:end)./(1-u(2:end));
% Time axis: two pre-shock entries (-12 and just before 0), then the figure grid
t_grid=[-12; -0.00001; diff.t_grid'];
% From here on u is a column vector with the initial value repeated twice
u=[u(1),u]';
eu_ini = lambda(1).*u(1)./(1-u(1));
eu=[ eu_ini; eu_ini; eu'];
mean_m =  mean_t_fun(exp(ln_m_KFE([param_start.m_l,param_start.m_u],1,aux)), Q_start);


%% time aggregated
% Aggregate the continuous-time paths to monthly values: each month is
% split into n_agg sub-steps, the series are interpolated onto that fine
% grid, and every month's sub-steps are averaged (sum of values / n_agg).
n_agg = 200;
t_grid_agg = 0:aux.T_fig;
t_grid_tmp = linspace(0,1,aux.T_fig*n_agg+1).*aux.T_fig;
lambda_tmp = interp1( aux.t_vec, diff_t.lambda, t_grid_tmp,'linear','extrap');
ete_tmp = interp1( aux.t_vec, diff_t.ete, t_grid_tmp,'linear','extrap');
hires_tmp = interp1( aux.t_vec, diff_t.hires_tot, t_grid_tmp,'linear','extrap');
% Passed untransposed, like the other series (interp1 returns a vector
% result with the orientation of the query points either way)
new_chains_tmp = interp1( aux.t_vec, diff_t.new_chains, t_grid_tmp,'linear','extrap');

u_tmp = interp1( aux.t_vec, diff_t.u, t_grid_tmp,'linear','extrap');
u_tmp(1)   = Q_start(1).*param.s; % measure u(0) prior to shock to get eu right
u_agg      = interp1( aux.t_vec, diff_t.u, t_grid_agg,'linear','extrap');
u_agg(1)   = Q_start(1).*param.s; % measure u(0) prior to shock to get eu right

% reshape(...) puts the n_agg sub-steps of each month in one column, so
% sum(...) returns one value per month.
new_chains_agg = sum(reshape(new_chains_tmp(1:end-1)/n_agg, n_agg, aux.T_fig));
hires_agg = sum(reshape(hires_tmp(1:end-1)/n_agg, n_agg, aux.T_fig));
% Rates are weighted by the relevant stock within the month and divided by
% the stock at the beginning of the month.
ete_agg = sum(reshape(ete_tmp(1:end-1).*(1-u_tmp(1:end-1))./n_agg, n_agg, aux.T_fig))./(1-u_agg(1:end-1));
lambda_agg = sum(reshape(lambda_tmp(1:end-1)/n_agg.*u_tmp(1:end-1), n_agg, aux.T_fig))./u_agg(1:end-1);
% eu on the fine grid: (du/dt + lambda*u)/(1-u) with a forward difference
eu_tmp = (u_tmp(2:end) - u_tmp(1:end-1))./(t_grid_tmp(2:end) - t_grid_tmp(1:end-1))./(1-u_tmp(1:end-1)) +  lambda_tmp(1:end-1).*u_tmp(1:end-1)./(1-u_tmp(1:end-1));
% The division by the beginning-of-month stock is applied after the sum,
% as for ete_agg and lambda_agg. It used to sit inside sum(...), dividing
% an n_agg-by-T_fig matrix by a 1-by-T_fig row, which needs implicit
% expansion and differs from the lines above only in rounding.
eu_agg = sum(reshape(eu_tmp/n_agg.*(1-u_tmp(1:end-1)), n_agg, aux.T_fig))./(1-u_agg(1:end-1));

% Monthly (time-aggregated) series. The first two rows (t = -12 and t = 0)
% carry the initial-state values; the aggregate of month k is reported at
% t = k.
varNames = {'Time_months', 'Unemployment', 'UE', 'EU', 'ete', 'total_hires', 'new_chains'};

T = table( ...
    [-12;t_grid_agg'], ...
    [u_ini,u_ini,u_agg(2:end)]', ...
    [param_start.lambda,param_start.lambda,lambda_agg]', ...
    [eu_ini,eu_ini,eu_agg]', ...
    [ete_ini,ete_ini,ete_agg]', ...
    [hires_ini,hires_ini,hires_agg]', ...
    [new_chains_ini,new_chains_ini,new_chains_agg]', ...
    'VariableNames', varNames);

% Written next to the continuous-time table on the same sheet, starting at column Q
writetable(T, [datapath,filename], 'Sheet', 'transition_OJS', 'Range', 'Q1')
   

% Continuous-time paths on the figure grid (a 'G_m_h' column used to be
% listed after 'new_chains' and is disabled).
varNames = {'Time_months', 'Productivity', 'Output', 'Average_m', ...
    'Unemployment', 'UE', 'EU', 'ete', 'total_hires', 'new_chains', ...
    'm_l', 'm_h', 'm_u'};

% Columns follow varNames. Productivity is 1 for the two pre-shock rows and
% 1 - param.shock afterwards; Output is Average_m / alpha * (1 - u).
T = table( ...
    t_grid, ...
    [1; 1; (1-param.shock).*ones(size(diff.t_grid'))], ...
    [mean_m; mean_m; diff.mean_m']./param.alpha.*(1-u), ...
    [mean_m; mean_m; diff.mean_m'], ...
    u, ...
    [lambda(1); lambda'], ...
    eu, ete, hires, new_chains, ...
    m_l_t, m_h_t, m_u_t, ...
    'VariableNames', varNames);

writetable(T, [datapath,filename], 'Sheet', 'transition_OJS', 'Range', 'A1')

% Store the transition results and the converged theta path in aux and
% overwrite the file that was loaded at the top of this script, so the next
% run can start from this theta path.
aux.diff_t = diff_t;
aux.theta  = theta;
save([path,'aux_transition.mat'], 'aux');
